# from d2l import torch as d2l
from torch import nn

from config import *
from transformer import TransformerEncoder, TransformerDecoder, EncoderDecoder
from attention import sequence_mask
from data import loaddata, MyDataset


class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy loss that ignores padded positions.

    Shapes:
        `pred`:      (batch_size, num_steps, vocab_size) unnormalized logits.
        `label`:     (batch_size, num_steps) target token indices.
        `valid_len`: (batch_size,) count of non-padding tokens per sequence.

    Returns a (batch_size,) tensor: per-sequence loss averaged over all
    `num_steps` positions, with padded positions contributing 0 to the sum
    (note: the divisor is `num_steps`, not `valid_len`, matching the
    original d2l-style implementation).
    """

    def __init__(self, **kwargs):
        # Fix: configure the per-element reduction once at construction
        # instead of mutating `self.reduction` on every forward() call.
        kwargs['reduction'] = 'none'
        super().__init__(**kwargs)

    def forward(self, pred, label, valid_len):
        # Build the padding mask directly by broadcasting: position j of
        # sequence i is kept iff j < valid_len[i]. (Equivalent to masking a
        # ones tensor with sequence_mask, without the extra allocation.)
        steps = torch.arange(label.shape[1], device=label.device)
        weights = (steps[None, :] < valid_len[:, None]).to(pred.dtype)
        # nn.CrossEntropyLoss expects the class dimension second --
        # (N, C) or (N, C, d1, ...) -- so move vocab_size to dim 1.
        unweighted_loss = super().forward(pred.permute(0, 2, 1), label)
        # unweighted_loss: (batch, steps), one loss per token. Zero out the
        # padded positions, then average over the step dimension (dim=1).
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss


def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device, first_train = True):
    """Train an encoder-decoder model for sequence-to-sequence learning.

    Uses teacher forcing: the decoder is fed <bos> plus the ground-truth
    target shifted right by one position. Prints the summed batch losses
    once per epoch. Set `first_train=False` when `net` already carries
    trained weights (e.g. loaded from a checkpoint).
    """

    def _xavier_init(module):
        # Xavier-initialize Linear layers and every GRU weight matrix
        # (biases are left at their defaults).
        if type(module) == nn.Linear:
            nn.init.xavier_uniform_(module.weight)
        if type(module) == nn.GRU:
            for name in module._flat_weights_names:
                if "weight" in name:
                    nn.init.xavier_uniform_(module._parameters[name])

    if first_train:
        net.apply(_xavier_init)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = MaskedSoftmaxCELoss()
    net.train()

    for _ in range(num_epochs):
        epoch_total = 0
        for batch in data_iter:
            X, X_valid_len, Y, Y_valid_len = [t.to(device) for t in batch]
            optimizer.zero_grad()
            # Decoder input: a <bos> column followed by Y without its last
            # token (inputs carry <bos> but not <eos>; targets the reverse).
            bos_col = torch.full((Y.shape[0], 1), tgt_vocab['<bos>'],
                                 dtype=torch.long, device=device)
            dec_input = torch.cat([bos_col, Y[:, :-1]], 1)
            # Teacher forcing: feed the source and the gold prefix, ask the
            # model to reproduce the gold target.
            Y_hat, _ = net(X, dec_input, X_valid_len)
            # Both Y_hat and Y contain <eos> and no <bos>. Sum the
            # per-sequence losses so backward() gets a scalar.
            batch_loss = criterion(Y_hat, Y, Y_valid_len).sum()
            batch_loss.backward()
            # Gradient clipping (disabled):
            # d2l.grad_clipping(net, 1)
            optimizer.step()
            epoch_total += batch_loss.item()
        print(epoch_total)



if __name__ == "__main__":
    # Load the preprocessed dataset and the two vocabularies produced by the
    # data pipeline. NOTE(review): torch.load unpickles arbitrary objects --
    # only load checkpoint/dataset files from a trusted source.
    dataset = torch.load(DATA_ROOT + "dataset{}".format(num_examples))
    src_vocab = torch.load(DATA_ROOT + "eng_vocab{}".format(num_examples))
    tgt_vocab = torch.load(DATA_ROOT + "zh_vocab{}".format(num_examples))

    # Wrap the dataset in a batch iterator (batch size 64).
    train_iter = loaddata(dataset, 64)
    print(len(src_vocab))

    # To train from scratch, construct the encoder/decoder here instead of
    # loading checkpoints; the hyperparameters (vocab size, key/query/value
    # sizes, hidden size, ...) come from config via the star import.
    # encoder = TransformerEncoder(
    #     len(src_vocab), key_size, query_size, value_size, num_hiddens,
    #     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    #     num_layers, dropout)
    # decoder = TransformerDecoder(
    #     len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    #     norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    #     num_layers, dropout)
    encoder = torch.load(MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples))
    decoder = torch.load(MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples))

    net = EncoderDecoder(encoder, decoder)
    # first_train=False: the loaded modules are already initialized, so skip
    # the Xavier re-initialization inside train_seq2seq.
    train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device, first_train=False)

    # Persist the fine-tuned sub-modules back to their checkpoint files.
    torch.save(encoder, MODEL_ROOT + "trans_encoder{}.mdl".format(num_examples))
    torch.save(decoder, MODEL_ROOT + "trans_decoder{}.mdl".format(num_examples))

